import torch.distributed
from data.kitti_data import KittiDataset
from data.nuscenes_data import NuscenesDataset
from data.apollo_data import ApolloDataset

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Sampler
import torch
import random
import os
import numpy as np

def make_data_loader(args, distributed=False):
    """Build the train, validation and test data loaders for ``args.dataset``.

    Args:
        args: Configuration namespace. The attributes read here are
            ``dataset``, ``root``, ``npoints``, ``voxel_size``, ``data_list``,
            ``augment``, ``clip_length`` and ``batch_size``.
        distributed: When true, the train/val loaders are driven by a
            ``SequentialClipSampler`` configured from the current
            ``torch.distributed`` world size and rank.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple. The train and
        validation loaders iterate over ``ClipDataset`` wrappers; the test
        loader iterates over the raw test dataset one sample at a time.

    Raises:
        NotImplementedError: If ``args.dataset`` is not one of ``'kitti'``,
            ``'nusc'`` or ``'apollo'``.
    """

    def _build(dataset_cls, root, sequences, augment):
        # All dataset classes are constructed with the same positional layout.
        return dataset_cls(root, sequences, args.npoints, args.voxel_size,
                           args.data_list, augment)

    # Datasets are always created in the order val -> test -> train, and only
    # the training split receives the augmentation setting.
    dataset_name = args.dataset
    if dataset_name == 'kitti':
        val_dataset = _build(KittiDataset, args.root, ['06', '07'], 0.0)
        test_dataset = _build(KittiDataset, args.root, ['08', '09', '10'], 0.0)
        train_dataset = _build(KittiDataset, args.root,
                               ['00', '01', '02', '03', '04', '05'], args.augment)
    elif dataset_name == 'nusc':
        val_dataset = _build(NuscenesDataset, args.root, ['val'], 0.0)
        test_dataset = _build(NuscenesDataset, args.root, ['test'], 0.0)
        train_dataset = _build(NuscenesDataset, args.root, ['train'], args.augment)
    elif dataset_name == 'apollo':
        # The validation split is read from the same 'TrainData' directory as
        # the training split; only the test split uses 'TestData'.
        train_root = os.path.join(args.root, 'TrainData')
        test_root = os.path.join(args.root, 'TestData')
        val_dataset = _build(ApolloDataset, train_root, ['val'], 0.0)
        test_dataset = _build(ApolloDataset, test_root, ['test'], 0.0)
        train_dataset = _build(ApolloDataset, train_root, ['train'], args.augment)
    else:
        raise NotImplementedError('Dataset not implemented')

    # Wrap train/val so that frames are addressed through fixed-length clips.
    frames_per_clip = args.clip_length
    train_clip_dataset = ClipDataset(train_dataset, clip_length=frames_per_clip)
    val_clip_dataset = ClipDataset(val_dataset, clip_length=frames_per_clip)

    # NOTE(review): without `distributed` no sampler is passed, so the loaders
    # fall back to DataLoader's default ordering - confirm this is intended.
    if distributed:
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        train_sampler = SequentialClipSampler(
            train_clip_dataset, batch_size=args.batch_size,
            num_replicas=world_size, rank=rank, shuffle=True)
        val_sampler = SequentialClipSampler(
            val_clip_dataset, batch_size=args.batch_size,
            num_replicas=world_size, rank=rank, shuffle=False)
    else:
        train_sampler = None
        val_sampler = None

    # Settings shared by all three loaders.
    common = dict(num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_clip_dataset, batch_size=args.batch_size,
                              sampler=train_sampler, **common)
    val_loader = DataLoader(val_clip_dataset, batch_size=args.batch_size,
                            sampler=val_sampler, **common)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, **common)

    return train_loader, val_loader, test_loader

class ClipDataset(torch.utils.data.Dataset):
    """Frame-level view of a dataset whose frames are grouped into clips.

    Every sequence of ``original_dataset`` is cut into consecutive,
    non-overlapping clips of ``clip_length`` frames. When a sequence length is
    not a multiple of ``clip_length`` one extra clip is added that ends on the
    last frame of the sequence (it overlaps the previous clip). Sequences
    shorter than ``clip_length`` cannot provide a full clip and are skipped,
    so a clip never reaches outside its own sequence.

    Flat index ``i`` of this dataset addresses frame ``i % clip_length`` of
    clip ``i // clip_length``.

    Args:
        original_dataset: Wrapped dataset. It must expose
            ``data_len_sequence`` (an iterable of per-sequence frame counts)
            and support integer indexing.
            NOTE(review): assumes frames are indexed globally in the same
            order as ``data_len_sequence`` - confirm against the dataset.
        clip_length: Number of frames per clip; must be at least 1.

    Raises:
        ValueError: If ``clip_length`` is smaller than 1.
    """

    def __init__(self, original_dataset, clip_length=12):
        # A non-positive clip length would never advance the clip cursor.
        if clip_length < 1:
            raise ValueError(
                f'clip_length must be a positive integer, got {clip_length!r}')
        self.original_dataset = original_dataset
        self.clip_length = clip_length
        self.clips = self._create_clips()
        # Private random state; only touched by reset_seed() in this class.
        self.randg = np.random.RandomState()

    def _create_clips(self):
        """Return the list of ``(seq_idx, start, end)`` clips.

        ``start``/``end`` are global frame indices into the wrapped dataset,
        with ``end`` exclusive and ``end - start == clip_length``.
        """
        clips = []
        seq_start = 0  # Global index of the first frame of the current sequence
        for seq_idx, sequence_length in enumerate(self.original_dataset.data_len_sequence):
            seq_end = seq_start + sequence_length
            if sequence_length >= self.clip_length:
                last_full_start = seq_end - self.clip_length
                for clip_start in range(seq_start, last_full_start + 1, self.clip_length):
                    clips.append((seq_idx, clip_start, clip_start + self.clip_length))
                # Cover the leftover frames with a clip aligned to the end of
                # the sequence.
                if sequence_length % self.clip_length:
                    clips.append((seq_idx, last_full_start, seq_end))
            # Shorter sequences are skipped: a clip anchored at their end would
            # start before the first frame of the sequence.
            seq_start = seq_end
        return clips

    def __len__(self):
        """Return the total number of frames over all clips."""
        return len(self.clips) * self.clip_length

    def __getitem__(self, idx):
        """Return the frame of the wrapped dataset addressed by flat ``idx``."""
        clip_idx, frame_idx = divmod(idx, self.clip_length)
        _, start, _ = self.clips[clip_idx]
        return self.original_dataset[start + frame_idx]

    def reset_seed(self, seed=0):
        """Re-seed the dataset's private random state."""
        self.randg.seed(seed)


class SequentialClipSampler(DistributedSampler):
    """Distributed sampler that yields frame indices grouped by clip batch.

    Clips are split across ranks in contiguous chunks. Within a rank, clips
    are taken ``batch_size`` at a time and emitted frame-major: first frame 0
    of every clip in the group, then frame 1, and so on. A ``DataLoader``
    using the same ``batch_size`` therefore receives batches that hold one
    frame position of ``batch_size`` different clips.

    The clip list is padded (by repeating clips from the start of the order)
    to a multiple of ``batch_size * num_replicas`` so that every rank gets the
    same number of complete clip groups.

    Shuffling uses a private generator seeded with ``seed + epoch``, which
    gives every rank the same permutation. Call ``set_epoch`` before each
    epoch to obtain a new order.

    Args:
        dataset: Dataset exposing ``clips`` and ``clip_length`` (and
            ``__len__``, which the parent class requires).
        batch_size: Number of clips per group; must be at least 1.
        num_replicas: Number of participating processes (see parent class).
        rank: Rank of the current process (see parent class).
        shuffle: Whether to shuffle the clip order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, dataset, batch_size, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        if batch_size < 1:
            raise ValueError(
                f'batch_size must be a positive integer, got {batch_size!r}')
        self.batch_size = batch_size
        self.num_clips = len(dataset.clips)
        self.clip_length = dataset.clip_length

    def _clips_per_rank(self):
        """Return how many clips each rank processes, padding included."""
        if self.num_clips == 0:
            return 0
        clips_per_step = self.batch_size * self.num_replicas
        num_steps = -(-self.num_clips // clips_per_step)  # ceiling division
        return num_steps * self.batch_size

    def __iter__(self):
        clip_indices = list(range(self.num_clips))

        if self.shuffle:
            # The global `random` state is not guaranteed to match across
            # processes; a locally seeded generator keeps all ranks on the
            # same permutation so their slices stay disjoint.
            seed = getattr(self, 'seed', 0) + self.epoch
            random.Random(seed).shuffle(clip_indices)

        # Pad by cycling through the clip order so every rank receives whole
        # groups of `batch_size` clips.
        clips_per_rank = self._clips_per_rank()
        total_clips = clips_per_rank * self.num_replicas
        if clip_indices and total_clips > len(clip_indices):
            repeats = -(-total_clips // len(clip_indices))
            clip_indices = (clip_indices * repeats)[:total_clips]

        # Contiguous chunk of clips owned by this rank.
        first = self.rank * clips_per_rank
        rank_clips = clip_indices[first:first + clips_per_rank]

        # Emit flat dataset indices frame-major within each clip group.
        final_indices = []
        for group_start in range(0, len(rank_clips), self.batch_size):
            group = rank_clips[group_start:group_start + self.batch_size]
            for frame_idx in range(self.clip_length):
                final_indices.extend(
                    clip_idx * self.clip_length + frame_idx for clip_idx in group)

        return iter(final_indices)

    def __len__(self):
        """Return the number of indices yielded by ``__iter__`` on this rank."""
        return self._clips_per_rank() * self.clip_length
